{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "4b90aa2d-bf3e-4d43-807d-87d65209d444",
   "metadata": {},
   "source": [
    "# 量子门与量子线路基础操作"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "b36749a3-4e3e-4fb3-9c13-f5a3a3794cd6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pylint: disable=W0104\n",
    "import sys\n",
    "import mindspore as ms\n",
    "\n",
    "path = '../'\n",
    "sys.path.append(path)                                # 添加自主开发的量子模拟器代码所在路径\n",
    "ms.set_context(mode=ms.PYNATIVE_MODE)                # 使用动态图\n",
    "# ms.set_context(mode=ms.GRAPH_MODE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "7a4d0663-1c29-43ec-8549-33a3087ad880",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from qusmoke.gates import H, X, Y, Z, RX, RY, RZ, CNOT, ZZ, SWAP, U1, U2, U3\n",
    "from qusmoke.circuit import Circuit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e209eb92-59dc-4573-a993-55424a49f2a0",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Circuit<\n",
      "  (gates): SequentialCell<\n",
      "    (0): H<0>\n",
      "    (1): X<0, [2, 3]>\n",
      "    (2): Y<1>\n",
      "    (3): Z<1, 2>\n",
      "    (4): RX(param)<3>\n",
      "    (5): RY(_param_)<2, 3>\n",
      "    (6): ZZ(_param_)<(0, 1)>\n",
      "    (7): SWAP<(1, 3)>\n",
      "    (8): U3(u3_params)<0>\n",
      "    (9): U3(u3_params)<1>\n",
      "    (10): U3(u3_params)<2>\n",
      "    (11): U3(u3_params)<3>\n",
      "    (12): X<1, [0, 2, 3]>\n",
      "    (13): RY(_param_)<3, [0, 1, 2]>\n",
      "    >\n",
      "  >\n"
     ]
    }
   ],
   "source": [
    "# 构建一个基本的线路，线路中的每个参数门均为 mindspore.nn.Cell 的子类\n",
    "\n",
    "circ = Circuit([\n",
    "    H(0),\n",
    "    X(0, [2,3]),\n",
    "    Y(1),\n",
    "    Z(1, 2),\n",
    "    RX('param').on(3),\n",
    "    RY(1.0).on(2, 3),\n",
    "    ZZ(2.0).on((0, 1)),\n",
    "    SWAP((1, 3)),\n",
    "    U3(1.0, 2.0, 3.0).on(0),\n",
    "    U3(1.0, 2.0, 3.0).on(1),\n",
    "    U3(1.0, 2.0, 3.0).on(2),\n",
    "    U3(1.0, 2.0, 3.0).on(3),\n",
    "    X(1, [0, 2, 3]),\n",
    "    RY(2.0).on(3, [0, 1, 2])\n",
    "])\n",
    "\n",
    "print(circ)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3f58659e-1bb9-4828-8f0c-208ee4132ece",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['param',\n",
       " '_param_',\n",
       " '_param_',\n",
       " 'u3_params',\n",
       " 'u3_params',\n",
       " 'u3_params',\n",
       " 'u3_params',\n",
       " '_param_']"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 输出中的 '_param_' 是未显式命名的参数所使用的默认参数名\n",
    "circ.params_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e440de16-2ef2-4949-8326-4859e30435fa",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Circuit<\n",
      "  (gates): SequentialCell<\n",
      "    (0): H<0>\n",
      "    (1): X<0, [2, 3]>\n",
      "    (2): Y<1>\n",
      "    (3): Z<1, 2>\n",
      "    (4): RX(param)<3>\n",
      "    (5): RY(_param_)<2, 3>\n",
      "    (6): ZZ(_param_)<(0, 1)>\n",
      "    (7): SWAP<(1, 3)>\n",
      "    (8): U3(u3_params)<0>\n",
      "    (9): U3(u3_params)<1>\n",
      "    (10): U3(u3_params)<2>\n",
      "    (11): U3(u3_params)<3>\n",
      "    (12): X<1, [0, 2, 3]>\n",
      "    (13): RY(_param_)<3, [0, 1, 2]>\n",
      "    (14): RZ(param_z)<3>\n",
      "    (15): X<2>\n",
      "    (16): Y<0, 3>\n",
      "    (17): Z<1>\n",
      "    (18): Z<2>\n",
      "    >\n",
      "  >\n"
     ]
    }
   ],
   "source": [
    "# 线路支持加法运算：可用 += 追加单个门、门列表或另一条线路\n",
    "circ += RZ('param_z').on(3)\n",
    "circ += [X(2), Y(0, 3)]\n",
    "circ += Circuit([Z(1), Z(2)])\n",
    "\n",
    "print(circ)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ae391004-8807-4aae-9761-701f1406d9f6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "=================================================Circuit Summary=================================================\n",
      "|Total number of gates  : 19.                                                                                   |\n",
      "|Parameter gates        : 9.                                                                                    |\n",
      "|with 9 parameters are  :                                                                                       |\n",
      "|param, _param_, _param_, u3_params, u3_params, u3_params, u3_params, _param_, param_z                        . |\n",
      "|Number qubit of circuit: 4                                                                                     |\n",
      "=================================================================================================================\n"
     ]
    }
   ],
   "source": [
    "# 打印线路信息\n",
    "\n",
    "circ.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "14f172ea-b809-490a-b74e-82e82aecfa89",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.08451916-0.05577248j,  0.16271482+0.0295833j ,\n",
       "       -0.00641707-0.04917871j,  0.17163612-0.24420995j,\n",
       "        0.16867927+0.05835877j,  0.13320881+0.259294j  ,\n",
       "        0.13453737+0.00128568j,  0.15733941+0.1533909j ,\n",
       "        0.14758667+0.01421016j,  0.16689745+0.08064345j,\n",
       "        0.03495108+0.08881146j,  0.07131908-0.05617252j,\n",
       "        0.28291166-0.45243284j,  0.06628986+0.31992695j,\n",
       "        0.09377108-0.3911432j ,  0.12404367+0.21276072j], dtype=complex64)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 获取线路状态\n",
    "\n",
    "circ.get_qs()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8360b9d4-b95b-4491-a5aa-80f6a1de3d25",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n_qubit: 4\n",
      "qs:\n",
      "[ 0.02713373-0.09186346j  0.13572592-0.07779382j -0.03236606-0.03037797j\n",
      " -0.02406999-0.2790749j   0.15023226-0.07741043j  0.2647003 +0.07822535j\n",
      "  0.13123125-0.07579522j  0.24045055+0.05868228j  0.09202915+0.09560809j\n",
      "  0.07232455+0.15972522j -0.02893956+0.08921152j  0.0811841 +0.00335759j\n",
      "  0.46028954-0.2083434j  -0.11549888+0.28698888j  0.40331292-0.20640226j\n",
      " -0.09110902+0.26201683j]\n"
     ]
    }
   ],
   "source": [
    "# 为具名的可变参数赋予具体的值；创建时未给定名字的参数（默认名为 '_param_'）不支持赋值\n",
    "qs = circ.get_qs({'param': 1.0, 'param_z': 2.0})\n",
    "\n",
    "print(f'n_qubit: {circ.n_qubit}')\n",
    "print(f'qs:\\n{qs}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "e05a0a1b-3999-4f05-9ce8-d34cc33205ad",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Parameter (name=4.param, shape=(), dtype=Float32, requires_grad=True),\n",
       " Parameter (name=5.param, shape=(), dtype=Float32, requires_grad=True),\n",
       " Parameter (name=6.param, shape=(), dtype=Float32, requires_grad=True),\n",
       " Parameter (name=8.param, shape=(3,), dtype=Float32, requires_grad=True),\n",
       " Parameter (name=9.param, shape=(3,), dtype=Float32, requires_grad=True),\n",
       " Parameter (name=10.param, shape=(3,), dtype=Float32, requires_grad=True),\n",
       " Parameter (name=11.param, shape=(3,), dtype=Float32, requires_grad=True),\n",
       " Parameter (name=13.param, shape=(), dtype=Float32, requires_grad=True),\n",
       " Parameter (name=14.param, shape=(), dtype=Float32, requires_grad=True)]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 打印线路可训练参数\n",
    "\n",
    "circ.trainable_params()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "38415944-73e8-4307-8029-443272f15807",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 将线路变为编码线路\n",
    "\n",
    "circ.as_encoder()\n",
    "# 或执行 circ.no_grad()，功能与 circ.as_encoder() 一致\n",
    "circ.trainable_params()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "6ad839b4-1d3e-4d11-94fe-1a98764f2c06",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "n_qubit: 4\n",
      "qs:\n",
      "[ 0.02713373-0.09186347j  0.13572591-0.07779382j -0.03236607-0.03037797j\n",
      " -0.02406997-0.27907495j  0.15023226-0.07741042j  0.26470028+0.07822533j\n",
      "  0.13123125-0.07579524j  0.2404506 +0.05868228j  0.09202917+0.0956081j\n",
      "  0.07232457+0.15972524j -0.02893956+0.08921154j  0.08118412+0.00335757j\n",
      "  0.46028954-0.20834341j -0.11549886+0.28698891j  0.40331302-0.20640233j\n",
      " -0.09110898+0.26201691j]\n"
     ]
    }
   ],
   "source": [
    "# 线路状态和 mindquantum 对比\n",
    "# 需要安装 mindquantum 0.8，可在 CPU 上运行进行对比\n",
    "\n",
    "import mindquantum as mq\n",
    "    \n",
    "mq_circ = mq.Circuit([\n",
    "    mq.H(0),\n",
    "    mq.X(0, [2,3]),\n",
    "    mq.Y(1),\n",
    "    mq.Z(1, 2),\n",
    "    mq.RX('param').on(3),\n",
    "    mq.RY(1.0).on(2, 3),\n",
    "    mq.ZZ(2.0).on((0, 1)),\n",
    "    mq.SWAP((1, 3)),\n",
    "    mq.U3(1.0, 2.0, 3.0).on(0),\n",
    "    mq.U3(1.0, 2.0, 3.0).on(1),\n",
    "    mq.U3(1.0, 2.0, 3.0).on(2),\n",
    "    mq.U3(1.0, 2.0, 3.0).on(3),\n",
    "    mq.X(1, [0, 2, 3]),\n",
    "    mq.RY(2.0).on(3, [0, 1, 2]),\n",
    "    mq.RZ('param_z').on(3),\n",
    "    mq.X(2),\n",
    "    mq.Y(0, 3),\n",
    "    mq.Z(1),\n",
    "    mq.Z(2)\n",
    "])\n",
    "\n",
    "mq_qs = mq_circ.get_qs(pr={'param': 1.0, 'param_z': 2.0})\n",
    "\n",
    "print(f'n_qubit: {mq_circ.n_qubits}')\n",
    "print(f'qs:\\n{mq_qs}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "4bbbef45-b64e-434a-ac51-ff4abf654b2c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ True  True  True  True  True  True  True  True  True  True  True  True\n",
      "  True  True  True  True]\n"
     ]
    }
   ],
   "source": [
    "# 对比可知，自定义的线路与 mindquantum 的输出一致\n",
    "\n",
    "print(np.isclose(qs, mq_qs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "7cbd9b2c-1d6a-4452-b722-80e631a5e65d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "X<0>"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 说明：本模拟器（此处导入的 qusmoke）中所有门均需要先实例化，不支持 mindquantum 中的\n",
    "# X.on(obj) 写法，请使用 X(obj) 代替\n",
    "\n",
    "X(0)         # OK\n",
    "# X.on(0)    # NOT OK"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
